/**
 * Copyright (c) 2016-2021, Bosco.Liao (bosco_liao@126.com).
 *
 * Licensed under the Apache License, Version 2.0 (the "License"); you may not
 * use this file except in compliance with the License. You may obtain a copy of
 * the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations under
 * the License.
 */
package org.iherus.shiro.util;

import java.io.ByteArrayOutputStream;
import java.io.Closeable;
import java.io.File;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.nio.charset.Charset;
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;
import java.util.Base64;

/**
 * Static helpers that compute an MD5 digest of a value and render it as a
 * lower-case hex string or a Base64 string.
 *
 * <p>
 * The class keeps no mutable shared state: every hash call obtains its own
 * {@link MessageDigest} instance, so the static methods may be called from
 * several threads at once.
 *
 * <p>
 * Supported input types (for both source and salt) are {@code byte[]},
 * {@code char[]}, {@code String}, {@link File} and {@link InputStream}. Text is
 * encoded as UTF-8. Files and streams are read fully into memory before being
 * hashed.
 *
 * @author Bosco.Liao
 */
public final class Md5Utils {

	private static final Charset UTF_8 = Charset.forName("UTF-8");

	private static final String ALGORITHM_NAME = "MD5";

	/** Default number of hash rounds; also the lower bound applied to caller-supplied values. */
	private static final int DEFAULT_ITERATIONS = 1;

	/** Size of the copy buffer (and initial capacity) used when draining a stream. */
	private static final int DEFAULT_BUFFER_SIZE = 1024;

	/** Lookup table mapping a 4-bit value to its lower-case hex digit. */
	private static final char[] DIGITS = {
			'0', '1', '2', '3', '4', '5', '6', '7',
			'8', '9', 'a', 'b', 'c', 'd', 'e', 'f'
	};

	private Md5Utils() {
		throw new InstantiationError("Utility class must not be instantiated.");
	}

	/**
	 * Hashes {@code source} once, without salt. Same as
	 * {@link #getMd5ToHex(Object)}.
	 *
	 * @param source byte[], char[], String, File or InputStream; must not be null
	 * @return MD5 hex string.
	 * @throws IllegalArgumentException if {@code source} is null
	 * @throws CodecException if the type is unsupported or reading fails
	 */
	public static String getMd5(final Object source) {
		return getMd5ToHex(source, null, DEFAULT_ITERATIONS);
	}

	/**
	 * Hashes {@code source} once with the given salt. Same as
	 * {@link #getMd5ToHex(Object, Object)}.
	 *
	 * @param source byte[], char[], String, File or InputStream; must not be null
	 * @param salt   same supported types as {@code source}; null or empty means no salt
	 * @return MD5 hex string.
	 * @throws IllegalArgumentException if {@code source} is null
	 * @throws CodecException if a type is unsupported or reading fails
	 */
	public static String getMd5(final Object source, final Object salt) {
		return getMd5ToHex(source, salt, DEFAULT_ITERATIONS);
	}

	/**
	 * Hashes {@code source} with the given salt and number of rounds. Same as
	 * {@link #getMd5ToHex(Object, Object, int)}.
	 *
	 * @param source         byte[], char[], String, File or InputStream; must not be null
	 * @param salt           same supported types as {@code source}; null or empty means no salt
	 * @param hashIterations number of rounds; values below 1 are treated as 1
	 * @return MD5 hex string.
	 * @throws IllegalArgumentException if {@code source} is null
	 * @throws CodecException if a type is unsupported or reading fails
	 */
	public static String getMd5(final Object source, final Object salt, final int hashIterations) {
		return getMd5ToHex(source, salt, hashIterations);
	}

	/**
	 * Hashes {@code source} once, without salt, and hex-encodes the digest.
	 *
	 * @param source byte[], char[], String, File or InputStream; must not be null
	 * @return MD5 hex string.
	 * @throws IllegalArgumentException if {@code source} is null
	 * @throws CodecException if the type is unsupported or reading fails
	 */
	public static String getMd5ToHex(final Object source) {
		return getMd5ToHex(source, null, DEFAULT_ITERATIONS);
	}

	/**
	 * Hashes {@code source} once with the given salt and hex-encodes the digest.
	 *
	 * @param source byte[], char[], String, File or InputStream; must not be null
	 * @param salt   same supported types as {@code source}; null or empty means no salt
	 * @return MD5 hex string.
	 * @throws IllegalArgumentException if {@code source} is null
	 * @throws CodecException if a type is unsupported or reading fails
	 */
	public static String getMd5ToHex(final Object source, final Object salt) {
		return getMd5ToHex(source, salt, DEFAULT_ITERATIONS);
	}

	/**
	 * Hashes {@code source} once, without salt, and Base64-encodes the digest.
	 *
	 * @param source byte[], char[], String, File or InputStream; must not be null
	 * @return MD5 base64 string.
	 * @throws IllegalArgumentException if {@code source} is null
	 * @throws CodecException if the type is unsupported or reading fails
	 */
	public static String getMd5ToBase64(final Object source) {
		return getMd5ToBase64(source, null, DEFAULT_ITERATIONS);
	}

	/**
	 * Hashes {@code source} once with the given salt and Base64-encodes the
	 * digest.
	 *
	 * @param source byte[], char[], String, File or InputStream; must not be null
	 * @param salt   same supported types as {@code source}; null or empty means no salt
	 * @return MD5 base64 string.
	 * @throws IllegalArgumentException if {@code source} is null
	 * @throws CodecException if a type is unsupported or reading fails
	 */
	public static String getMd5ToBase64(final Object source, final Object salt) {
		return getMd5ToBase64(source, salt, DEFAULT_ITERATIONS);
	}

	/**
	 * Hashes {@code source} with the given salt and number of rounds and
	 * hex-encodes the digest.
	 *
	 * @param source         byte[], char[], String, File or InputStream; must not be null
	 * @param salt           same supported types as {@code source}; null or empty means no salt
	 * @param hashIterations number of rounds; values below 1 are treated as 1
	 * @return MD5 hex string.
	 * @throws IllegalArgumentException if {@code source} is null
	 * @throws CodecException if a type is unsupported or reading fails
	 */
	public static String getMd5ToHex(final Object source, final Object salt, final int hashIterations) {
		return toHexString(md5HashByObject(source, salt, hashIterations));
	}

	/**
	 * Hashes {@code source} with the given salt and number of rounds and
	 * Base64-encodes the digest.
	 *
	 * @param source         byte[], char[], String, File or InputStream; must not be null
	 * @param salt           same supported types as {@code source}; null or empty means no salt
	 * @param hashIterations number of rounds; values below 1 are treated as 1
	 * @return MD5 base64 string.
	 * @throws IllegalArgumentException if {@code source} is null
	 * @throws CodecException if a type is unsupported or reading fails
	 */
	public static String getMd5ToBase64(final Object source, final Object salt, final int hashIterations) {
		return toBase64String(md5HashByObject(source, salt, hashIterations));
	}

	/**
	 * Encodes the bytes with the basic (padded) Base64 alphabet.
	 *
	 * @param source bytes to encode; must not be null
	 * @return Base64 text
	 * @throws IllegalArgumentException if {@code source} is null
	 */
	public static final String toBase64String(final byte[] source) {
		if (source == null) {
			throw new IllegalArgumentException("Argument[source] cannot be null.");
		}
		return Base64.getEncoder().encodeToString(source);
	}

	/**
	 * Encodes the bytes as lower-case hex, two characters per byte.
	 *
	 * @param source bytes to encode; must not be null
	 * @return hex text of length {@code 2 * source.length}
	 * @throws IllegalArgumentException if {@code source} is null
	 */
	public static final String toHexString(final byte[] source) {
		if (source == null) {
			throw new IllegalArgumentException("Argument[source] cannot be null.");
		}
		final char[] out = new char[source.length << 1];
		int j = 0;
		for (final byte b : source) {
			// The masks discard the sign extension that happens when b widens to int.
			out[j++] = DIGITS[(0xF0 & b) >>> 4];
			out[j++] = DIGITS[0x0F & b];
		}
		return new String(out);
	}

	/**
	 * Reads the stream to its end and returns everything that was read. The
	 * stream is always closed before this method returns, whether it succeeds or
	 * fails; a failure while closing is ignored.
	 *
	 * @param in stream to drain; must not be null
	 * @return all bytes read from {@code in}
	 * @throws IllegalArgumentException if {@code in} is null
	 * @throws IOException if reading fails
	 */
	public static final byte[] toBytes(InputStream in) throws IOException {
		if (in == null) {
			throw new IllegalArgumentException("InputStream must not be null.");
		}
		ByteArrayOutputStream out = new ByteArrayOutputStream(DEFAULT_BUFFER_SIZE);
		byte[] buffer = new byte[DEFAULT_BUFFER_SIZE];
		int bytesRead;
		try {
			while ((bytesRead = in.read(buffer)) != -1) {
				out.write(buffer, 0, bytesRead);
			}
			return out.toByteArray();
		} finally {
			closeQuietly(in);
		}
	}

	/**
	 * Validates the source, clamps the iteration count to at least one round,
	 * converts source and salt to bytes (source first) and hashes them.
	 */
	private static byte[] md5HashByObject(Object source, Object salt, int hashIterations) {
		if (source == null) {
			throw new IllegalArgumentException("Argument[source] cannot be null.");
		}
		int iterations = Math.max(DEFAULT_ITERATIONS, hashIterations);
		byte[] sourceBytes = toBytes(source);
		byte[] saltBytes = salt == null ? null : toBytes(salt);
		return md5Hash(sourceBytes, saltBytes, iterations);
	}

	/**
	 * Computes MD5 over {@code salt} followed by {@code source}, then re-hashes
	 * the resulting digest {@code hashIterations - 1} more times. The salt only
	 * takes part in the first round; a null or empty salt is skipped.
	 */
	private static byte[] md5Hash(byte[] source, byte[] salt, int hashIterations) {
		MessageDigest digest = getMd5Digest();
		if (salt != null && salt.length > 0) {
			digest.reset();
			digest.update(salt);
		}
		byte[] hashed = digest.digest(source);
		for (int i = 1; i < hashIterations; i++) {// already hashed once above.
			digest.reset();
			hashed = digest.digest(hashed);
		}
		return hashed;
	}

	/**
	 * Returns a fresh MD5 {@link MessageDigest}.
	 *
	 * @throws UnknownAlgorithmException if the JVM provides no MD5 implementation;
	 *                                   the original exception is kept as the cause
	 */
	private static MessageDigest getMd5Digest() {
		try {
			return MessageDigest.getInstance(ALGORITHM_NAME);
		} catch (NoSuchAlgorithmException e) {
			throw new UnknownAlgorithmException(
					"No native 'MD5' MessageDigest instance available on the current JVM.", e);
		}
	}

	/**
	 * Converts a supported value to bytes. A {@code byte[]} is returned as-is (not
	 * copied); text is encoded as UTF-8; a File or InputStream is read fully and
	 * closed.
	 *
	 * @throws IllegalArgumentException if {@code source} is null
	 * @throws CodecException if the type is unsupported, or if reading a File or
	 *                        InputStream fails
	 */
	private static byte[] toBytes(Object source) {
		if (source == null) {
			throw new IllegalArgumentException("Argument for byte conversion cannot be null.");
		}
		if (source instanceof byte[]) {
			return (byte[]) source;
		}
		if (source instanceof char[]) {
			return new String((char[]) source).getBytes(UTF_8);
		}
		if (source instanceof String) {
			return ((String) source).getBytes(UTF_8);
		}
		if (source instanceof File || source instanceof InputStream) {
			// Only these branches can actually fail; both I/O and runtime failures
			// are reported to callers as CodecException.
			try {
				InputStream in = source instanceof File
						? new FileInputStream((File) source)
						: (InputStream) source;
				return toBytes(in);
			} catch (IOException | RuntimeException e) {
				throw new CodecException("Unable to convert source [" + source + "] to byte array", e);
			}
		}
		throw new CodecException(
				"Only the following types of objects are supported: byte[], char[], String, File or InputStream.");
	}

	/**
	 * Close and don't throw.
	 */
	private static void closeQuietly(Closeable closeable) {
		try {
			if (closeable != null) {
				closeable.close();
			}
		} catch (IOException ignored) {
			// Best effort: a failure to close must not mask the read result or error.
		}
	}

	/**
	 * Thrown when a value cannot be converted to bytes: unsupported type, or a
	 * failure while reading a File or InputStream.
	 */
	public static class CodecException extends RuntimeException {

		private static final long serialVersionUID = 1411948854906245432L;

		public CodecException(String message) {
			super(message);
		}

		public CodecException(String message, Throwable cause) {
			super(message, cause);
		}
	}

	/**
	 * Thrown when the MD5 algorithm is not available on the running JVM.
	 */
	public static class UnknownAlgorithmException extends RuntimeException {

		private static final long serialVersionUID = -9057567450240331588L;

		public UnknownAlgorithmException(String message) {
			super(message);
		}

		public UnknownAlgorithmException(String message, Throwable cause) {
			super(message, cause);
		}
	}

}
